# ComfyUI_WaveletColorfix/nodes.py
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from comfy.utils import ProgressBar

# Import the color correction functions from our local file
from .wavelet_color_fix import adain_color_fix, wavelet_color_fix

class WaveletColorFix:
    """ComfyUI node that transfers the colors of ``source_image`` onto ``target_image``.

    Target frame ``i`` is paired with source frame ``i``. The source frame is
    resized (Lanczos) to the target frame's width/height. Then the selected
    ``align_method`` (AdaIN or wavelet) is applied. Target frames that have no
    matching source frame are passed through without color correction.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required inputs for ComfyUI."""
        return {
            "required": {
                "target_image": ("IMAGE",),
                "source_image": ("IMAGE",),
                "align_method": (["adain", "wavelet"],),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "⚪Lum3on/ColorFix"

    @staticmethod
    def _prepare_frame(frame: torch.Tensor) -> torch.Tensor:
        """Return one (H, W, C) frame as a CHW float CPU tensor clamped to [0, 1].

        torchvision's ``ToPILImage`` turns float tensors into 8-bit data via
        ``mul(255).byte()``. Values outside [0, 1] would overflow that uint8
        cast instead of saturating, so they are clamped first.
        """
        return frame.permute(2, 0, 1).cpu().float().clamp(0.0, 1.0)

    def process(self, target_image: torch.Tensor, source_image: torch.Tensor, align_method: str):
        """Color-correct every frame of ``target_image`` using ``source_image``.

        Args:
            target_image: Batch of frames to recolor, laid out as (B, H, W, C).
            source_image: Batch of reference frames in the same layout. Frame
                ``i`` is used as the reference for target frame ``i``. Its
                spatial size may differ; it is resized to the target size.
            align_method: ``"adain"`` or ``"wavelet"``. Any other value leaves
                the frames uncorrected.

        Returns:
            A 1-tuple holding a float32 CPU tensor of shape (B, H, W, 3).
        """
        num_target_frames = target_image.shape[0]
        num_source_frames = source_image.shape[0]
        num_matched = min(num_target_frames, num_source_frames)

        if num_target_frames != num_source_frames:
            print(f"[WaveletColorFix Warning] Target frames ({num_target_frames}) and Source frames ({num_source_frames}) do not match.")
            print(f"[WaveletColorFix] Color-correcting the first {num_matched} frames; any remaining target frames are passed through uncorrected.")

        _, target_height, target_width, _ = target_image.shape

        if num_target_frames == 0:
            # torch.stack() rejects an empty list, so return an empty batch directly.
            return (torch.empty((0, target_height, target_width, 3), dtype=torch.float32),)

        # Unknown methods map to None -> frames are passed through uncorrected.
        color_fixers = {"adain": adain_color_fix, "wavelet": wavelet_color_fix}
        color_fix = color_fixers.get(align_method)

        # Converters are stateless; build them once instead of once per frame.
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        pbar = ProgressBar(num_target_frames)
        output_frames = []

        for i in tqdm(range(num_target_frames), desc="Wavelet Color Fixing"):
            target_pil = to_pil(self._prepare_frame(target_image[i]))

            if i < num_source_frames and color_fix is not None:
                source_pil = to_pil(self._prepare_frame(source_image[i]))
                # Match the target's size so the color statistics are computed over aligned pixels.
                source_pil_resized = source_pil.resize((target_width, target_height), Image.LANCZOS)
                fixed_pil = color_fix(target=target_pil, source=source_pil_resized)
            else:
                # No reference frame (or unknown method): keep the target frame as-is.
                fixed_pil = target_pil

            # Back to ComfyUI's (H, W, C) layout; convert("RGB") guarantees 3 channels.
            output_frames.append(to_tensor(fixed_pil.convert("RGB")).permute(1, 2, 0))
            pbar.update(1)

        return (torch.stack(output_frames),)